{% from 'macros.jinja' import node_llm_config_to_dspy_lm, escape_prompt %}

import dspy
from typing import Dict, Any, List, Optional
from langwatch_nlp.studio.dspy import LLMNode, TemplateAdapter

{# Base name shared by both generated classes; falls back to "Anonymous" when the component has no name. #}
{% set class_name = component.name or "Anonymous" %}

{# Locate the workflow node referenced by the prompting technique (if any).
   A namespace object is used because a plain `set` inside a `for` loop does
   not survive the loop's scope. For the matching node, the import statement
   registered for its class in PROMPTING_TECHNIQUES is emitted at column 0. #}
{% set ns = namespace(decorator_node=None) %}
{% for node in workflow.nodes %}
    {# Guard on prompting_technique first: it may be unset (the node class
       further down tests it for truthiness), so `.ref` is only read when it exists. #}
    {% if prompting_technique and node.id == prompting_technique.ref %}
        {% set ns.decorator_node = node %}
{{ PROMPTING_TECHNIQUES[node.data.cls]["import"] }}
    {% endif %}
{% endfor %}

{# Emit the "code" entry of every json_schema_types item verbatim at module level;
   the Signature class refers to these entries through their "model_name". #}
{% for key, value in json_schema_types.items() %}
{{ value["code"] }}
{% endfor %}

{# Generated dspy.Signature subclass.
   - The instructions parameter, when non-empty, is rendered through
     escape_prompt as the first statement of the class body.
   - Chat messages, when present, are stored in a `_messages` class attribute.
   - One dspy.InputField / dspy.OutputField is declared per component field. #}
class {{ class_name }}Signature(dspy.Signature):
    {% if parameters.get('instructions', '') %}
    {{ escape_prompt(parameters.get('instructions', '')) }}

    {% endif %}
    {# Default to a list here as well, matching the loop right below. #}
    {% if parameters.get('messages', []) %}
    _messages = [
        {% for message in parameters.get('messages', []) %}
        {"role": "{{ message.role }}", "content": {{ escape_prompt(message.content) }}},
        {% endfor %}
    ]
    {% endif %}

    {% for input_field in component.inputs or [] %}
    {{ input_field.identifier }}: {{ FIELD_TYPE_TO_DSPY_TYPE[input_field.type.value] }} = dspy.InputField()
    {% endfor %}
    {% for output_field in component.outputs or [] %}
    {# Compare on .value like every other type lookup in this template, so the
       check also holds when the field type is a plain (non-str) Enum.
       JSON-schema outputs are annotated with their generated model name. #}
    {% if output_field.type.value == "json_schema" %}
    {{ output_field.identifier }}: {{ json_schema_types[output_field.identifier]["model_name"] }} = dspy.OutputField()
    {% else %}
    {{ output_field.identifier }}: {{ FIELD_TYPE_TO_DSPY_TYPE[output_field.type.value] }} = dspy.OutputField()
    {% endif %}
    {% endfor %}
    {# Keep the class body syntactically valid when the component declares no fields. #}
    {% if not component.inputs and not component.outputs %}
    pass
    {% endif %}


{# Generated LLMNode subclass: passes the signature's predictor, the optional
   LM configuration and the optional demonstrations to LLMNode.__init__, and
   exposes a forward() whose keyword arguments mirror the component inputs. #}
class {{ class_name }}(LLMNode):
    def __init__(self):
        {% if prompting_technique %}
        {# Apply prompting technique decorator #}
        {% if ns.decorator_node %}
        predict = {{ PROMPTING_TECHNIQUES[ns.decorator_node.data.cls]["class"] }}({{ class_name }}Signature)
        {% else %}
        {# Decorator node {{ prompting_technique.ref }} not found #}
        predict = dspy.Predict({{ class_name }}Signature)
        {% endif %}
        {% else %}
        {# Standard prediction #}
        predict = dspy.Predict({{ class_name }}Signature)
        {% endif %}

        {# Configure LLM #}
        {% if llm_config %}
        lm = {{ node_llm_config_to_dspy_lm(llm_config) }}
        {% endif %}

        {# Process demonstrations if available #}
        {# NOTE(review): demonstration keys are emitted as bare keyword names, so they
           are assumed to be valid Python identifiers; confirm against the caller. #}
        {% if demonstrations %}
        demos = [
            {% for d in demonstrations %}
                dspy.Example({% for k, v in d.items() %}{{ k }}={{ v.__repr__() }}, {% endfor %}),
            {% endfor %}
        ]
        {% endif %}

        {# NOTE(review): node_id is interpolated unescaped into a string literal;
           confirm ids can never contain quotes or backslashes. #}
        super().__init__(
            node_id="{{ node_id }}",
            name="{{ class_name }}",
            predict=predict,
            {% if llm_config %}
            lm=lm,
            {% endif %}
            {% if demonstrations %}
            demos=demos
            {% endif %}
        )

    {# "self" seeds the parameter list so the rendered signature has no dangling
       comma when the component declares no inputs. #}
    {% set input_args = ["self"] %}
    {% set forward_args = [] %}
    {% for input_field in component.inputs or [] %}
        {# Every input defaults to None, so its annotation is wrapped in Optional. #}
        {% set _ = input_args.append(input_field.identifier + ": Optional[" + FIELD_TYPE_TO_DSPY_TYPE[input_field.type.value] + "] = None") %}
        {% set _ = forward_args.append(input_field.identifier + "=" + input_field.identifier) %}
    {% endfor %}

    def forward({{ input_args | join(", ") }}):
        return super().forward({{ forward_args | join(", ") }})
